import argparse
import os
import tensorflow as tf
from tensorflow import keras as K
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE


# Silence TF C++ INFO/WARNING logs; must be set before TF does heavy work.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
parser = argparse.ArgumentParser()
parser.add_argument("--n_class", type=int, default=10)   # number of classes
parser.add_argument("--n_hid_pc", type=int, default=10)  # prototypes per class
parser.add_argument("--lr", type=float, default=0.001)   # initial learning rate
parser.add_argument("--epoch", type=int, default=5)      # training epochs
parser.add_argument("--batch_size", type=int, default=64)
args = parser.parse_args()


# data: MNIST, flattened to 784-d vectors, labels one-hot encoded
mnist = K.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()   # [n, 28, 28], [n]
x_train, x_test = x_train / 255.0, x_test / 255.0  # scale pixels to [0, 1]
print("label max:", tf.reduce_max(y_test))
# Flatten each 28x28 image to a 784-dim row vector.
x_train = tf.reshape(x_train, [-1, 784])
x_test = tf.reshape(x_test, [-1, 784])
# x_train = tf.math.l2_normalize(x_train, axis=1)
# x_test = tf.math.l2_normalize(x_test, axis=1)
y_train = tf.one_hot(y_train, 10)
y_test = tf.one_hot(y_test, 10)
print("data:", type(x_train), x_train.shape, y_test.shape)

# Only the training set is shuffled; the test set keeps its order.
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(args.batch_size)
test_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, y_test)).batch(args.batch_size)


def euclidean(A, B=None, sqrt=False):
    """Pairwise (squared) Euclidean distance matrix via the Gram trick.

    Args:
        A: [n, d] tensor.
        B: [m, d] tensor; when None (or the very same tensor as A),
            distances are computed within A.
        sqrt: if True, return true distances instead of squared ones.

    Returns:
        [n, m] tensor D with D[i, j] = ||A_i - B_j||^2 (square-rooted when
        `sqrt` is set), clipped below at 0 against rounding error.
    """
    if B is None or B is A:
        cross = tf.matmul(A, A, transpose_b=True)
        sq_a = tf.linalg.diag_part(cross)
        sq_b = sq_a
    else:
        cross = tf.matmul(A, B, transpose_b=True)
        sq_a = tf.linalg.diag_part(tf.matmul(A, A, transpose_b=True))
        sq_b = tf.linalg.diag_part(tf.matmul(B, B, transpose_b=True))
    dist = tf.maximum(sq_a[:, None] - 2.0 * cross + sq_b[None, :], 0.0)
    if sqrt:
        # Shift exact zeros by a tiny epsilon before sqrt (avoids NaN
        # gradients at 0), then force those entries back to zero.
        zeros = tf.cast(tf.equal(dist, 0.0), "float32")
        dist = tf.math.sqrt(dist + zeros * 1e-16) * (1.0 - zeros)
    return dist


def top_k_mask(D, k, rand_pick=False):
    """M[i][j] = 1 <=> D[i][j] is one of the BIGGEST k in i-th row.

    Args:
        D: (n, n), distance matrix
        k: param `k` of kNN
        rand_pick: true or false
            - if `True`, only ONE of the top-K element in each row will be selected randomly;
            - if `False`, ALL the top-K elements will be selected as usual.

    Returns:
        (n, n) float32 mask of zeros and ones.

    Ref:
        - https://cloud.tencent.com/developer/ask/196899
        - https://blog.csdn.net/HackerTom/article/details/103587415
    """
    n_row = tf.shape(D)[0]
    n_col = tf.shape(D)[1]

    k_val, k_idx = tf.math.top_k(D, k)  # k_idx: [n_row, k] column indices
    if rand_pick:
        # BUG FIX: tf.random_uniform is the removed TF1 name; the TF2 API
        # is tf.random.uniform (the original raised AttributeError here).
        c_idx = tf.random.uniform([n_row, 1],
                                  minval=0, maxval=k,
                                  dtype="int32")
        r_idx = tf.range(n_row, dtype="int32")[:, None]
        idx = tf.concat([r_idx, c_idx], axis=1)
        # Keep one randomly chosen winner per row.
        k_idx = tf.gather_nd(k_idx, idx)[:, None]

    # Linearize (row, col) indices, scatter ones into a flat vector, then
    # reshape back into the [n_row, n_col] mask.
    idx_offset = (tf.range(n_row) * n_col)[:, None]
    k_idx_linear = k_idx + idx_offset
    k_idx_flat = tf.reshape(k_idx_linear, [-1, 1])

    updates = tf.ones_like(k_idx_flat[:, 0], "int32")
    mask = tf.scatter_nd(k_idx_flat, updates, [n_row * n_col])
    mask = tf.reshape(mask, [-1, n_col])
    mask = tf.cast(mask, "float32")

    return mask


@tf.custom_gradient
def lvq(X, W):
    """Hard nearest-prototype assignment with a hand-crafted LVQ gradient.

    X: [n, d] batch of inputs
    W: [m, d] prototype (codebook) matrix
    Returns y: [n, m], one-hot rows where y[i, j] = 1 iff W[j] is the
    nearest prototype to X[i]; gradients flow through `grad` below rather
    than through the non-differentiable argmin.
    """
    D = euclidean(X, W)  # [n, m]
    y = top_k_mask(- D, 1)  # [n, m], minus for nearest

    def grad(dy):
        # dy: [n, m]
        # (2*dy - 1) turns the upstream signal into a +/- sign; multiplying
        # by y keeps only each sample's winning prototype.
        mask_sgn = tf.expand_dims((2 * dy - 1) * y, 2)  # [n, m, 1]
        X_minus_W = tf.expand_dims(X, 1) - tf.expand_dims(W, 0)  # [n, m, d]
        # Signed sum of (x - w) over the samples each prototype won,
        # normalized by the (at least 1) win count — the LVQ update rule.
        dW = tf.reduce_sum(X_minus_W * mask_sgn, 0)
        dW = dW / tf.maximum(1., tf.reduce_sum(mask_sgn, 0))
        # NOTE(review): the gradient w.r.t. X is returned as X itself, not a
        # true derivative — presumably acceptable because X is the network
        # input and never needs a gradient; confirm before reusing `lvq`
        # on intermediate tensors.
        return X, dW

    return y, grad


class LVQ(K.Model):
    """Learning Vector Quantization model with n_hid_pc prototypes per class.

    Prototype row j belongs to class j // n_hid_pc, matching the
    `tf.repeat(tf.range(n_class), n_hid_pc)` labeling used for plotting.
    """

    def __init__(self, dim, n_class, n_hid_pc):
        super(LVQ, self).__init__()
        # W: [n_class * n_hid_pc, dim] trainable prototype matrix.
        self.W = tf.Variable(tf.random.truncated_normal(
            [n_class*n_hid_pc, dim]))
        # Q maps prototypes to classes: Q[j, c] = 1 iff prototype j belongs
        # to class c. BUG FIX: the row stride must be n_hid_pc (prototypes
        # per class), not n_class — the original slice
        # `i*n_class:(i+1)*n_class` only labeled rows correctly when
        # n_hid_pc == n_class, and silently clipped out of range otherwise.
        Q = np.zeros([n_hid_pc*n_class, n_class])
        for i in range(n_class):
            Q[i*n_hid_pc:(i+1)*n_hid_pc, i] = 1
        self.Q = tf.constant(Q, dtype="float32")

    def call(self, x):
        """x: [n, dim] -> [n, n_class] one-hot class score of the winner."""
        z = lvq(x, self.W)  # [n, n_class*n_hid_pc], one-hot nearest prototype
        return tf.matmul(z, self.Q)


model = LVQ(x_train.shape[1], args.n_class, args.n_hid_pc)  # dim = 784


class LVQ_Schedule(K.optimizers.schedules.LearningRateSchedule):
    """Linearly decaying learning rate: lr(n) = lr(0) * (1 - n / N).

    ref: https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py#L410-L511
    """

    def __init__(self, initial_learning_rate, n_iter):
        super(LVQ_Schedule, self).__init__()
        # Keep both values as float32 tensors so __call__ is pure TF math.
        self.initial_learning_rate = tf.convert_to_tensor(
            initial_learning_rate, dtype="float32")
        self.n_iter = tf.convert_to_tensor(n_iter, dtype="float32")

    def __call__(self, step):
        # Decays from lr(0) at step 0 down to 0 at step n_iter.
        decay_factor = 1. - step / self.n_iter
        return self.initial_learning_rate * decay_factor

    def get_config(self):
        # Serialization hook required by the LearningRateSchedule API.
        return {"initial_learning_rate": self.initial_learning_rate,
                "n_iter": self.n_iter}


# optimizer = K.optimizers.SGD(
#     learning_rate=LVQ_Schedule(args.lr, args.epoch), momentum=0.9)
# NOTE(review): args.lr is not passed here — Adam runs with its default
# learning rate; confirm whether --lr should be wired in.
optimizer = K.optimizers.Adam()


# Running metrics, reset at the start of every epoch in the training loop.
train_loss = K.metrics.Mean(name='train_loss')
train_accuracy = K.metrics.CategoricalAccuracy(name='train_accuracy')

test_loss = K.metrics.Mean(name='test_loss')
test_accuracy = K.metrics.CategoricalAccuracy(name='test_accuracy')


#@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and update the train metrics.

    Args:
        images: [b, 784] batch of flattened inputs.
        labels: [b, 10] one-hot targets.
    Returns:
        The batch L1 loss (scalar tensor).
    """
    with tf.GradientTape() as tape:
        outputs = model(images)
        batch_loss = tf.reduce_sum(tf.math.abs(labels - outputs))
    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    train_loss(batch_loss)
    train_accuracy(labels, outputs)
    return batch_loss


#@tf.function
def test_step(images, labels):
    """Evaluate one batch, update the test metrics, return #correct.

    Args:
        images: [b, 784] batch of flattened inputs.
        labels: [b, 10] one-hot targets.
    Returns:
        Scalar float tensor: number of correctly classified samples.
    """
    outputs = model(images)
    batch_loss = tf.reduce_sum(tf.math.abs(labels - outputs))

    test_loss(batch_loss)
    test_accuracy(labels, outputs)

    pred_cls = tf.argmax(outputs, axis=1)
    true_cls = tf.argmax(labels, axis=1)
    return tf.reduce_sum(tf.cast(pred_cls == true_cls, "float32"))


# Training loop: per-batch losses and per-epoch test accuracy are collected
# for the plots below.
loss_list, acc_list = [], []
for epoch in range(args.epoch):
    # Reset the running metrics at the start of each epoch.
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    for images, labels in train_ds:
        # images = tf.image.resize(images, [224, 224])
        # images = tf.tile(images, tf.constant([1, 1, 1, 3]))
        l = train_step(images, labels)
        loss_list.append(l.numpy())

    # Evaluate on the full test set once per epoch.
    n_corr = 0
    for images, labels in test_ds:
        # images = tf.image.resize(images, [224, 224])
        # images = tf.tile(images, tf.constant([1, 1, 1, 3]))
        _n_corr = test_step(images, labels)
        n_corr += _n_corr.numpy()
    acc_list.append(n_corr / y_test.shape[0])

    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch+1,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result()*100))


# plot loss (one point per training batch)
fig = plt.figure()
plt.title("loss")
plt.plot(np.arange(len(loss_list)), loss_list)
# plt.show()
fig.savefig("loss.png")

# plot accuracy (one point per epoch)
fig = plt.figure()
plt.title("accuracy")
plt.plot(np.arange(len(acc_list)), acc_list)
# plt.show()
fig.savefig("accuracy.png")


# T-SNE
# fea_list = []
# for i, (images, labels) in enumerate(test_ds):
#     fea, _ = model(images)
#     fea_list.append(fea.numpy())
#     if i > 5:
#         break

# F = np.vstack(fea_list)
# tsne = TSNE(n_components=2, init="pca", random_state=0)
# F = tsne.fit_transform(F)
# x_min, x_max = np.min(F, 0), np.max(F, 0)
# F = (F - x_min) / (x_max - x_min)
# NOTE(review): despite the "T-SNE" title, the TSNE projection above is
# commented out — this plots the first two RAW dimensions of the prototype
# matrix W, labeled by each prototype's class. Confirm whether the TSNE
# branch should be re-enabled.
F = model.W.numpy()
w_label = tf.repeat(tf.range(args.n_class), args.n_hid_pc).numpy()
fig = plt.figure()
plt.title("T-SNE")
for i in range(F.shape[0]):
    plt.text(F[i, 0], F[i, 1], str(w_label[i]),
             color=plt.cm.Set1(w_label[i] / 10.))
fig.savefig("tsne.png")
